{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aJ_pmgxvGur9"
      },
      "source": [
        "# teenyGPT\n",
        "## An educational example of a GPT-kind of autoregressive (causal) language model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tsdc7fDp40rQ"
      },
      "source": [
        "## Introduction\n",
        "\n",
        "In this example, we are going to implement an autoregressive model (ARM) for text generation, dubbed teenyGPT. The model used here is based on the architecture of so called Generative Pretrained Transformers (GPTs):\n",
        "- [Radford, A., Narasimhan, K., Salimans, T. and Sutskever, I., 2018. Improving language understanding by generative pre-training](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf),\n",
        "\n",
        "however, we will use a very simplified architecture.\n",
        "\n",
        "You can read more about ARMs in Chapter 2 of the following book:\n",
        "- [Tomczak, J.M., \"Deep Generative Modeling\", Springer, 2022](https://link.springer.com/book/10.1007/978-3-030-93158-2)\n",
        "\n",
        "You can read more about transformers in Chapter 12 of the following book:\n",
        "- [Prince, S.J.D., \"Understanding Deep Learning\", MIT Press, 2023](https://udlbook.github.io/udlbook/)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RvsuVNczG6pP"
      },
      "source": [
        "### Theory behind ARMs\n",
        "\n",
        "Let us consider a high-dimensional random variable $\\mathbf{x} \\in \\mathcal{X}^{T}$ where $\\mathcal{X} = \\{0,1,\\dots , L-1\\}$ or $\\mathcal{X} = \\mathbb{R}$. Our goal is to model $p(\\mathbf{x})$. We can apply the product rule to express this distribution as follows:\n",
        "$$\n",
        "p(\\mathbf{x}) = p(x_1) \\prod_{t=2}^{T} p(x_{t}|\\mathbf{x}_{<t}) ,\n",
        "$$\n",
        "where $\\mathbf{x}_{<t} = [x_1, x_2, \\ldots , x_{t-1}]^{\\top}$. For instance, for $\\mathbf{x} = [x_1, x_2, x_{3}]^{\\top}$, we have $p(\\mathbf{x}) = p(x_1) p(x_{2}|x_{1}) p(x_{3} | x_{1}, x_{2})$.\n",
        "\n",
        "The generative procedure is straightforward: We start with $x_1 \\sim p(x_1)$, and then we proceed with $x_t \\sim p(x_{t}|\\mathbf{x}_{<t})$ by plugging in all previously sampled variables $\\mathbf{x}_{<t}$. We can think of this procedure as a for-loop.\n",
        "\n",
        "Now, the main goal is how to parameterize conditional distributions $p(x_{d}|\\mathbf{x}_{<t})$. We can accomplish that by using neural networks, in particular, transformers. In this example, we focus on <i>decoder transformers</i> that utilize causal multi-head self-attention."
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Note\n",
        "\n",
        "In this example, we build a simple LLM model. For this purpose, we use a dataset consisting of $\\sim 8.5$k newspaper headlines, and each headline contain at most 150 letters (tokens). You are provided with a tokenizer for turning characters into a sequence of integers and padding, and text processing functions (e.g., removing special characters). Your model will be trained with 1.3M tokens per iteration, and depending on architecture choices, it can consist of thousands to few millions of weights.\n",
        "\n",
        "These numbers do not necessarilly impress anyone in the LLM community. However, please be aware that such datasets and models are not small and could be treated as a small-sized LLM-based problems. As you will notice in the end, we can still observe similar phenomena like hallucinations and the power of scaling up."
      ],
      "metadata": {
        "id": "sIaNyIwxSfN_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# checking if CUDA is available; training on CPU could be painful...\n",
        "!nvidia-smi"
      ],
      "metadata": {
        "id": "W4e7A6qTO2tG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "suzhlbWqxtD9"
      },
      "source": [
        "## IMPORTS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BjxkigYLxpB7"
      },
      "outputs": [],
      "source": [
        "# DO NOT REMOVE!\n",
        "import os\n",
        "\n",
        "import pickle\n",
        "\n",
        "import spacy\n",
        "import nltk\n",
        "from nltk.corpus import stopwords\n",
        "from nltk.stem import WordNetLemmatizer\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "!pip install datasets\n",
        "from datasets import load_dataset\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "!pip install pytorch_model_summary\n",
        "from pytorch_model_summary import summary"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cm23hRm6CqGh"
      },
      "outputs": [],
      "source": [
        "# Check if GPU is available and determine the device\n",
        "if torch.cuda.is_available():\n",
        "    device = 'cuda'\n",
        "else:\n",
        "    device = 'cpu'\n",
        "\n",
        "print(f'The available device is {device}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "81CxONpmMulC"
      },
      "outputs": [],
      "source": [
        "# mount drive: WE NEED IT FOR SAVING IMAGES! NECESSARY FOR GOOGLE COLAB!\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OoPb92zNM4UY"
      },
      "outputs": [],
      "source": [
        "# PLEASE CHANGE IT TO YOUR OWN GOOGLE DRIVE OR YOUR LOCAL DIR!\n",
        "results_model_dir = '/content/gdrive/My Drive/Colab/Results/'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I3zs31tOyCmq"
      },
      "source": [
        "## Auxiliary classes and functions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF0agzL7tDHK"
      },
      "source": [
        "Let us define some useful classes:\n",
        "1. DataProcessor: \"cleaning\" texts.\n",
        "2. Tokenizer: transforming characters to integers and padding."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LIBNVRNJtHSd"
      },
      "outputs": [],
      "source": [
        "class DataProcessor(object):\n",
        "    def __init__(self, ):\n",
        "        super().__init__()\n",
        "        nlp = spacy.load(\"en_core_web_sm\")\n",
        "        nltk.download('omw-1.4')\n",
        "        nltk.download(\"punkt\")\n",
        "        nltk.download(\"wordnet\")\n",
        "        nltk.download(\"stopwords\")\n",
        "\n",
        "    @staticmethod\n",
        "    def preprocess_text(text):\n",
        "        # Tokenize, remove punctuation and lowercase\n",
        "        tokens = nltk.word_tokenize(text)\n",
        "        tokens = [word.lower() for word in tokens if word.isalpha()]\n",
        "\n",
        "        # Remove stopwords and lemmatize\n",
        "        stop_words = set(stopwords.words(\"english\"))\n",
        "        lemmatizer = WordNetLemmatizer()\n",
        "        processed_text = [\n",
        "            lemmatizer.lemmatize(word) for word in tokens if word not in stop_words\n",
        "        ]\n",
        "\n",
        "        return \" \".join(processed_text)\n",
        "\n",
        "    def process_batch(self, texts):\n",
        "        return [self.preprocess_text(d) for d in texts]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4NjoCwjV3TN7"
      },
      "outputs": [],
      "source": [
        "class Tokenizer(object):\n",
        "    def __init__(self, max_length=0):\n",
        "        super().__init__()\n",
        "\n",
        "        self.max_length = max_length\n",
        "\n",
        "        self.alphabet_letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n",
        "\n",
        "        self.alphabet = self.prepare_alphabet()\n",
        "        self.decoded_alphabet = self.prepare_decoded_alphabet()\n",
        "\n",
        "    def prepare_alphabet(self):\n",
        "        # PREPARE THE ALPHABET (CHAR->INT)\n",
        "        # as a dictionary\n",
        "        alphabet = {}\n",
        "        alphabet['pad'] = 0  # add 'pad'\n",
        "        count = 1\n",
        "\n",
        "        for letter in self.alphabet_letters:\n",
        "            alphabet[letter] = count\n",
        "            count += 1\n",
        "\n",
        "        # add 'pad', 'bos', 'eos' tokens\n",
        "        alphabet[' '] = count\n",
        "        alphabet['cls'] = count + 1\n",
        "\n",
        "        return alphabet\n",
        "\n",
        "    def prepare_decoded_alphabet(self):\n",
        "        # PREPARE DECODED ALPHABET (INT->CHAR)\n",
        "        decoded_alphabet_ints = [i for i in range(len(self.alphabet_letters))]\n",
        "\n",
        "        decoded_alphabet = {}\n",
        "        decoded_alphabet[0] = 'pad'\n",
        "\n",
        "        for i in decoded_alphabet_ints:\n",
        "            decoded_alphabet[i+1] = self.alphabet_letters[i]\n",
        "\n",
        "            decoded_alphabet[i+2] = ' '\n",
        "        decoded_alphabet[i+3] = 'cls'\n",
        "\n",
        "        return decoded_alphabet\n",
        "\n",
        "    def encode(self, texts):\n",
        "        N = len(texts)\n",
        "\n",
        "        if self.max_length == 0:\n",
        "            max_length = 0\n",
        "            for i in range(N):\n",
        "                len_i = len(texts[i])\n",
        "                if len_i > max_length:\n",
        "                    max_length = len_i\n",
        "        else:\n",
        "            max_length = self.max_length\n",
        "\n",
        "        tokens = np.zeros((N, max_length+1))\n",
        "\n",
        "        for i in range(N):\n",
        "            len_i = len(texts[i])\n",
        "            for j in range(-1, max_length):\n",
        "                if j == -1:\n",
        "                    tokens[i,j+1] = self.alphabet['cls']\n",
        "                elif j >= len_i:\n",
        "                    tokens[i,j+1] = self.alphabet['pad']\n",
        "                else:\n",
        "                    if texts[i][j] == 'é':\n",
        "                        tokens[i,j+1] = self.alphabet['e']\n",
        "                    elif texts[i][j] == 'í':\n",
        "                        tokens[i,j+1] = self.alphabet['e']\n",
        "                    elif texts[i][j] == 'á':\n",
        "                        tokens[i,j+1] = self.alphabet['a']\n",
        "                    elif texts[i][j] == 'ó':\n",
        "                        tokens[i,j+1] = self.alphabet['o']\n",
        "                    elif texts[i][j] == 'æ':\n",
        "                        tokens[i,j+1] = self.alphabet['a']\n",
        "                    elif texts[i][j] == 'ä':\n",
        "                        tokens[i,j+1] = self.alphabet['a']\n",
        "                    else:\n",
        "                        tokens[i,j+1] = self.alphabet[texts[i][j]]\n",
        "\n",
        "        return tokens\n",
        "\n",
        "    def decode(self, tokens):\n",
        "        texts = []\n",
        "\n",
        "        for i in range(len(tokens)):\n",
        "            tokens_i = tokens[i,:]\n",
        "            text_i = ''\n",
        "            for j in range(len(tokens_i)):\n",
        "                if tokens_i[j] == 0:\n",
        "                    break\n",
        "                else:\n",
        "                    if self.decoded_alphabet[tokens_i[j]] != 'cls':\n",
        "                        text_i += self.decoded_alphabet[tokens_i[j]]\n",
        "            texts.append(text_i)\n",
        "\n",
        "        return texts"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Some useful functions:"
      ],
      "metadata": {
        "id": "VhiWi-j3mELC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def save_texts(sampled_texts, name=''):\n",
        "    # open file in write mode\n",
        "    with open(results_dir + '/samples_' + name + '.txt', 'w') as fp:\n",
        "        for item in sampled_texts:\n",
        "            # write each item in a new line\n",
        "            fp.write(\"%s\\n\" % item)"
      ],
      "metadata": {
        "id": "MoB-RuczmGlT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q60onqQR3TN8"
      },
      "source": [
        "# Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0K9XaWO_3TN8"
      },
      "outputs": [],
      "source": [
        "class Headers(Dataset):\n",
        "    \"\"\"A simple dataset based on headers. Source: https://huggingface.co/datasets/IlyaGusev/headline_cause\"\"\"\n",
        "\n",
        "    def __init__(self, dataprocessor, tokenizer, mode='train', num_training_data=None, transforms=None):\n",
        "        # LOAD DATA\n",
        "        dataset = load_dataset(\"IlyaGusev/headline_cause\", \"en_simple\")\n",
        "\n",
        "        # PREPARE DATA\n",
        "        if mode == 'train':\n",
        "            train_texts = dataprocessor.process_batch(dataset['train'][:]['left_title'] + dataset['train'][:]['right_title']) # list\n",
        "            if num_training_data is None:\n",
        "                self.data = torch.from_numpy(tokenizer.encode(train_texts)).long()\n",
        "            else:\n",
        "                self.data = torch.from_numpy(tokenizer.encode(train_texts))[:num_training_data].long()\n",
        "        elif mode == 'val':\n",
        "            validation_texts = dataprocessor.process_batch(dataset['validation'][:]['left_title'] + dataset['validation'][:]['right_title']) # list\n",
        "            self.data = torch.from_numpy(tokenizer.encode(validation_texts)).long()\n",
        "        else:\n",
        "            test_texts = dataprocessor.process_batch(dataset['test'][:]['left_title'] + dataset['test'][:]['right_title']) # list\n",
        "            self.data = torch.from_numpy(tokenizer.encode(test_texts)).long()\n",
        "\n",
        "        self.transforms = transforms\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        sample = self.data[idx]\n",
        "        if self.transforms:\n",
        "            sample = self.transforms(sample)\n",
        "        return sample"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q2LLOs0kn7iw"
      },
      "source": [
        "## Implementing ARMs with Transformers\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7cXhOwKAzW6Z"
      },
      "source": [
        "### Loss Function (NLL)\n",
        "Our loss function is the negative log-likelihood for the categorical distribution (i.e., the cross-entropy loss).\n",
        "\n",
        "Please note how it is implemented and how tokens (T) are handled."
      ]
    },
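    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a short worked derivation (the batch notation $B$ and $\\mathbf{x}_b$ is introduced here for illustration only): taking the logarithm of the factorization from the theory section turns the product into a sum, so the negative log-likelihood of a single sequence is\n",
        "$$\n",
        "-\\ln p(\\mathbf{x}) = -\\ln p(x_1) - \\sum_{t=2}^{T} \\ln p(x_{t}|\\mathbf{x}_{<t}) .\n",
        "$$\n",
        "In this notebook every sequence starts with the same `cls` token, so the $p(x_1)$ term carries no information and is left out: the model outputs at positions $1, \\ldots, T-1$ are compared with the tokens at positions $2, \\ldots, T$. For a mini-batch of $B$ sequences $\\mathbf{x}_1, \\ldots, \\mathbf{x}_B$, `reduction='mean'` then corresponds to\n",
        "$$\n",
        "\\mathcal{L} = -\\frac{1}{B} \\sum_{b=1}^{B} \\sum_{t=2}^{T} \\ln p(x_{b,t}|\\mathbf{x}_{b,<t}) ,\n",
        "$$\n",
        "while `reduction='sum'` drops the factor $1/B$. Note that padded positions are not masked out, so predicting `pad` tokens also contributes to the sum."
      ]
    },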
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MrwQXSuEoFfH"
      },
      "outputs": [],
      "source": [
        "class LossFun(nn.Module):\n",
        "    def __init__(self,):\n",
        "        super().__init__()\n",
        "\n",
        "        self.loss = nn.NLLLoss(reduction='none')\n",
        "\n",
        "    def forward(self, y_model, y_true, reduction='sum'):\n",
        "        # y_model: B(atch) x T(okens) x V(alues)\n",
        "        # y_true: B x T\n",
        "        B, T, V = y_model.size()\n",
        "\n",
        "        y_model = y_model.view(B * T, V)\n",
        "        y_true = y_true.view(B * T,)\n",
        "\n",
        "        loss_matrix = self.loss(y_model, y_true) # B*T\n",
        "\n",
        "        if reduction == 'sum':\n",
        "            return torch.sum(loss_matrix)\n",
        "        elif reduction == 'mean':\n",
        "            loss_matrix = loss_matrix.view(B, T)\n",
        "            return torch.mean(torch.sum(loss_matrix, 1))\n",
        "        else:\n",
        "            raise ValueError('Reduction could be either `sum` or `mean`.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nhNTy5mn0XDT"
      },
      "source": [
        "### Transformer block\n",
        "\n",
        "Transformers consist of transformer block. In the cell below, please define a transformer block."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9vTmKHwrpUVa"
      },
      "outputs": [],
      "source": [
        "class TransformerBlock(nn.Module):\n",
        "    def __init__(self, num_emb, num_neurons, num_heads=4):\n",
        "        super().__init__()\n",
        "\n",
        "        # hyperparams\n",
        "        self.D = num_emb\n",
        "        self.H = num_heads\n",
        "        self.neurons = num_neurons\n",
        "\n",
        "        # components\n",
        "        self.msha = nn.MultiheadAttention(embed_dim=self.D, num_heads=self.H, batch_first=True)\n",
        "        self.layer_norm1 = nn.LayerNorm(self.D)\n",
        "        self.layer_norm2 = nn.LayerNorm(self.D)\n",
        "\n",
        "        self.mlp = nn.Sequential(nn.Linear(self.D, self.neurons * self.D),\n",
        "                                nn.GELU(),\n",
        "                                nn.Linear(self.neurons * self.D, self.D))\n",
        "\n",
        "    def forward(self, x, causal=True):\n",
        "        # Multi-Head Self-Attention\n",
        "        x_attn, _ = self.msha(x, x, x, is_causal=causal, attn_mask=torch.empty(1,1), need_weights=False)\n",
        "        # LayerNorm\n",
        "        x = self.layer_norm1(x_attn + x)\n",
        "        # MLP\n",
        "        x_mlp = self.mlp(x)\n",
        "        # LayerNorm\n",
        "        x = self.layer_norm2(x_mlp + x)\n",
        "\n",
        "        return x"
      ]
    },
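    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the role of causality concrete, here is a small worked example for $T = 3$ tokens (the symbols $\\mathbf{Q}$, $\\mathbf{K}$, $\\mathbf{M}$ and $d$ are introduced here for illustration and do not appear in the code). In each attention head, a mask $\\mathbf{M}$ is added to the scaled attention scores before the softmax:\n",
        "$$\n",
        "\\mathrm{softmax}\\left(\\frac{\\mathbf{Q}\\mathbf{K}^{\\top}}{\\sqrt{d}} + \\mathbf{M}\\right), \\qquad \\mathbf{M} = \\begin{bmatrix} 0 & -\\infty & -\\infty \\\\ 0 & 0 & -\\infty \\\\ 0 & 0 & 0 \\end{bmatrix},\n",
        "$$\n",
        "where $d$ is the dimensionality per head. The entries equal to $-\\infty$ receive zero attention weight, so the representation of token $t$ depends only on tokens $1, \\ldots, t$. This is what allows the output at position $t$ to be used for predicting $x_{t+1}$ without looking ahead."
      ]
    },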
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CLIEwIiw00op"
      },
      "source": [
        "### teenyGPT (Decoder-Transformer)\n",
        "\n",
        "Once we have a class for transformer blocks, we need to define a decoder-transformer that defines an auto-regressive model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xQIvee5Cp69V"
      },
      "outputs": [],
      "source": [
        "class teenyGPT(nn.Module):\n",
        "    def __init__(self, num_tokens, num_token_vals, num_emb, num_neurons, num_heads=2, dropout_prob=0.1, num_blocks=10, device='cpu'):\n",
        "        super().__init__()\n",
        "\n",
        "        # Remember, always credit the author, even if it's you ;)\n",
        "        print('teenyGPT by JT.')\n",
        "\n",
        "        # hyperparams\n",
        "        self.device = device\n",
        "        self.num_tokens = num_tokens\n",
        "        self.num_token_vals = num_token_vals\n",
        "        self.num_emb = num_emb\n",
        "        self.num_blocks = num_blocks\n",
        "\n",
        "        # embedding layer\n",
        "        self.embedding = torch.nn.Embedding(num_token_vals, num_emb)\n",
        "\n",
        "        # positional embedding\n",
        "        self.positional_embedding = nn.Embedding(num_tokens, num_emb)\n",
        "\n",
        "        # transformer blocks\n",
        "        self.transformer_blocks = nn.ModuleList()\n",
        "        for _ in range(num_blocks):\n",
        "            self.transformer_blocks.append(TransformerBlock(num_emb=num_emb, num_neurons=num_neurons, num_heads=num_heads))\n",
        "\n",
        "        # output layer (logits + softmax)\n",
        "        self.logits = nn.Sequential(nn.Linear(num_emb, num_token_vals))\n",
        "\n",
        "        # dropout layer\n",
        "        self.dropout = nn.Dropout(dropout_prob)\n",
        "\n",
        "        # loss function\n",
        "        self.loss_fun = LossFun()\n",
        "\n",
        "    def transformer_forward(self, x, causal=True, temperature=1.0):\n",
        "        # x: B(atch) x T(okens)\n",
        "        # embedding of tokens\n",
        "        x = self.embedding(x) # B x T x D\n",
        "        # embedding of positions\n",
        "        pos = torch.arange(0, x.shape[1], dtype=torch.long).unsqueeze(0).to(self.device)\n",
        "        pos_emb = self.positional_embedding(pos)\n",
        "        # dropout of embedding of inputs\n",
        "        x = self.dropout(x + pos_emb)\n",
        "\n",
        "        # transformer blocks\n",
        "        for i in range(self.num_blocks):\n",
        "            x = self.transformer_blocks[i](x)\n",
        "\n",
        "        # output logits\n",
        "        out = self.logits(x)\n",
        "\n",
        "        return F.log_softmax(out/temperature, 2)\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def sample(self, batch_size=4, temperature=1.0):\n",
        "        x_seq = np.asarray([[self.num_token_vals - 1] for i in range(batch_size)])\n",
        "\n",
        "        # sample next tokens\n",
        "        for i in range(self.num_tokens-1):\n",
        "            xx = torch.tensor(x_seq, dtype=torch.long, device=self.device)\n",
        "            # process x and calculate log_softmax\n",
        "            x_log_probs = self.transformer_forward(xx, temperature=temperature)\n",
        "            # sample i-th tokens\n",
        "            x_i_sample = torch.multinomial(torch.exp(x_log_probs[:,i]), 1).to(self.device)\n",
        "            # update the batch with new samples\n",
        "            x_seq = np.concatenate((x_seq, x_i_sample.to('cpu').detach().numpy()), 1)\n",
        "\n",
        "        return x_seq\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def top1_rec(self, x, causal=True):\n",
        "        x_prob = torch.exp(self.transformer_forward(x, causal=True))[:,:-1,:].contiguous()\n",
        "        _, x_rec_max = torch.max(x_prob, dim=2)\n",
        "        return torch.sum(torch.mean((x_rec_max.float() == x[:,1:].float().to(device)).float(), 1).float())\n",
        "\n",
        "    def forward(self, x, causal=True, temperature=1.0, reduction='mean'):\n",
        "        # get log-probabilities\n",
        "        log_prob = self.transformer_forward(x, causal=causal, temperature=temperature)\n",
        "\n",
        "        return self.loss_fun(log_prob[:,:-1].contiguous(), x[:,1:].contiguous(), reduction=reduction)"
      ]
    },
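    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The output layer and the `temperature` argument of `transformer_forward` can be summarized in one formula (the symbols $z_k$ and $\\tau$ are notation introduced here, not names used in the code). If $z_0, \\ldots, z_{L-1}$ are the logits computed from the prefix $\\mathbf{x}_{<t}$ and $\\tau > 0$ is the temperature, then\n",
        "$$\n",
        "p(x_t = k | \\mathbf{x}_{<t}) = \\frac{\\exp(z_k / \\tau)}{\\sum_{j=0}^{L-1} \\exp(z_j / \\tau)} .\n",
        "$$\n",
        "For $\\tau = 1$ this is the standard softmax; $\\tau < 1$ sharpens the distribution (samples become more conservative), and $\\tau > 1$ flattens it (samples become more diverse). In this notebook $L = 29$ (`num_token_vals`)."
      ]
    },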
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hLhgze7DA4yx"
      },
      "source": [
        "### Evaluation and training functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I9Dr3a6lqJ0W"
      },
      "outputs": [],
      "source": [
        "def evaluation(test_loader, name=None, model_best=None, epoch=None, device='cuda'):\n",
        "    # EVALUATION\n",
        "    if model_best is None:\n",
        "        # load best performing model\n",
        "        model_best = torch.load(name + '.model').to(device)\n",
        "\n",
        "    model_best.eval()\n",
        "    loss = 0.\n",
        "    rec = 1.\n",
        "    N = 0.\n",
        "    for indx_batch, test_batch in enumerate(test_loader):\n",
        "        loss_t = model_best.forward(test_batch.to(device), reduction='sum')\n",
        "        loss = loss + loss_t.item()\n",
        "\n",
        "        rec_t = model_best.top1_rec(test_batch.to(device))\n",
        "        rec = rec + rec_t.item()\n",
        "\n",
        "        N = N + test_batch.shape[0]\n",
        "    loss = loss / N\n",
        "    rec = rec / N\n",
        "\n",
        "    if epoch is None:\n",
        "        print(f'FINAL LOSS: nll={loss}, rec={rec}')\n",
        "    else:\n",
        "        print(f'Epoch: {epoch}, val nll={loss}, val rec={rec}')\n",
        "\n",
        "    return loss, rec\n",
        "\n",
        "def plot_curve(name, nll_val, ylabel='nll'):\n",
        "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
        "    plt.xlabel('epochs')\n",
        "    plt.ylabel(ylabel)\n",
        "    plt.savefig(name + '_' + ylabel + '_val_curve.pdf', bbox_inches='tight')\n",
        "    plt.close()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9ABgMeG0qFAP"
      },
      "outputs": [],
      "source": [
        "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader, device='cuda'):\n",
        "    nll_val = []\n",
        "    rec_val = []\n",
        "    best_nll = 1000.\n",
        "    patience = 0\n",
        "\n",
        "    # Main loop\n",
        "    for e in range(num_epochs):\n",
        "        # TRAINING\n",
        "        model.train()\n",
        "        for indx_batch, batch in enumerate(training_loader):\n",
        "            loss = model.forward(batch.to(device))\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward(retain_graph=True)\n",
        "            optimizer.step()\n",
        "\n",
        "        # Validation\n",
        "        loss_val, r_val = evaluation(val_loader, model_best=model, epoch=e, device=device)\n",
        "        nll_val.append(loss_val)  # save for plotting\n",
        "        rec_val.append(r_val)\n",
        "\n",
        "        if e == 0:\n",
        "            print('saved!')\n",
        "            torch.save(model, name + '.model')\n",
        "            best_nll = loss_val\n",
        "\n",
        "            sampled_tokens = model.sample(batch_size=64, temperature=1.0)\n",
        "            sampled_texts = tokenizer.decode(sampled_tokens)\n",
        "            save_texts(sampled_texts, name='epoch_' + str(e))\n",
        "\n",
        "        else:\n",
        "            if loss_val < best_nll:\n",
        "                print('saved!')\n",
        "                torch.save(model, name + '.model')\n",
        "                best_nll = loss_val\n",
        "                patience = 0\n",
        "\n",
        "                sampled_tokens = model.sample(batch_size=64, temperature=1.0)\n",
        "                sampled_texts = tokenizer.decode(sampled_tokens)\n",
        "                save_texts(sampled_texts, name='epoch_' + str(e))\n",
        "            else:\n",
        "                patience = patience + 1\n",
        "\n",
        "        if patience > max_patience:\n",
        "            break\n",
        "\n",
        "    nll_val = np.asarray(nll_val)\n",
        "    rec_val = np.asarray(rec_val)\n",
        "\n",
        "    np.save(name + '_nll_val.npy', nll_val)\n",
        "    np.save(name + '_rec_val.npy', rec_val)\n",
        "\n",
        "    return nll_val, rec_val"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kWr8N2u2qNTu"
      },
      "source": [
        "### Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bFTE5jtYpxDV"
      },
      "outputs": [],
      "source": [
        "dataprocessor = DataProcessor()\n",
        "tokenizer = Tokenizer(max_length=149)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UXHitzrYqNhY"
      },
      "outputs": [],
      "source": [
        "num_training_data = None  # None to take all training data\n",
        "\n",
        "#-dataset\n",
        "train_dataset = Headers(dataprocessor, tokenizer, num_training_data=num_training_data, mode=\"train\")\n",
        "validation_dataset = Headers(dataprocessor, tokenizer, mode=\"val\")\n",
        "test_dataset = Headers(dataprocessor, tokenizer, mode=\"test\")\n",
        "\n",
        "#-dataloaders\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "training_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
        "val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
        "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#-creating a dir for saving results\n",
        "name = 'teenyGPT'\n",
        "results_dir = results_model_dir + name + '/'\n",
        "if not(os.path.exists(results_dir)):\n",
        "  os.mkdir(results_dir)"
      ],
      "metadata": {
        "id": "OTBFl-bNb0s9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kmKDXMI0B231"
      },
      "source": [
        "In the next cell, please initialize the model. Please remember about commenting your code!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b73aaBDxqSYb"
      },
      "outputs": [],
      "source": [
        "num_tokens = 150 # do not modify!\n",
        "num_token_vals = 29  # do not modify!\n",
        "num_neurons = 512\n",
        "num_heads = 8\n",
        "num_blocks = 4\n",
        "num_emb = num_heads * 4\n",
        "causal=True # do not modify!\n",
        "\n",
        "lr = 1e-3 # learning rate; do not modify!\n",
        "num_epochs = 1000 # max. number of epochs; do not modify!\n",
        "max_patience = 5 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped; do not modify!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JQy_iNJs3TOD"
      },
      "outputs": [],
      "source": [
        "model = teenyGPT(num_tokens=num_tokens, num_token_vals=num_token_vals, num_emb=num_emb, num_neurons=num_neurons, num_heads=num_heads, num_blocks=num_blocks, device=device)\n",
        "model = model.to(device)\n",
        "# Print the summary (like in Keras)\n",
        "print(summary(model, torch.zeros(1, num_tokens, dtype=torch.long).to(device), show_input=False, show_hierarchical=False))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iC8AkWt4CURT"
      },
      "source": [
        "Initialize the optimizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a3nTSDe7ql08"
      },
      "outputs": [],
      "source": [
        "optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad == True], lr=lr)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P5GrzUcHFweG"
      },
      "source": [
        "### Training and final evaluation\n",
        "\n",
        "In the following two cells, we run the training and the final evaluation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VD7WuY6bqnBK"
      },
      "outputs": [],
      "source": [
        "# Training procedure\n",
        "nll_val, rec_val = training(name=results_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, optimizer=optimizer, training_loader=training_loader, val_loader=val_loader, device=device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JAuMt9_wquOI"
      },
      "outputs": [],
      "source": [
        "# Final evaluation\n",
        "test_loss, test_rec = evaluation(name=results_dir + name, test_loader=test_loader, device=device)\n",
        "\n",
        "with open(results_dir + name + '_test_loss.txt', \"w\") as f:\n",
        "    f.write('Test NLL: ' + str(test_loss)+'\\n'+'Test REC: ' + str(test_rec))\n",
        "    f.close()\n",
        "\n",
        "plot_curve(results_dir + name, nll_val, ylabel='nll')\n",
        "plot_curve(results_dir + name, rec_val, ylabel='rec')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LcQ0GPZy3TOE"
      },
      "outputs": [],
      "source": [
        "# Sample texts: load best model\n",
        "model_best = torch.load(results_dir + name + '.model')\n",
        "model_best = model_best.eval()\n",
        "\n",
        "# sample\n",
        "temperature = 1.0 # you can modify it\n",
        "num_samples = 64 # you can modify it\n",
        "\n",
        "sampled_tokens = model_best.sample(batch_size=num_samples, temperature=temperature)  # do not modify\n",
        "sampled_texts = tokenizer.decode(sampled_tokens)  # do not modify\n",
        "\n",
        "save_texts(sampled_texts, name='FINAL_' + str(temperature))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}